{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training a differentially private LSTM model for name classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we will build a differentially-private LSTM model to classify names to their source languages, which is the same task as in the tutorial **NLP From Scratch** (https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html). Since the objective of this tutorial is to demonstrate the effective use of an LSTM with privacy guarantees, we will be utilizing it in place of the bare-bones RNN model defined in the original tutorial. Specifically, we use the `DPLSTM` module from `opacus.layers.dp_lstm` to facilitate the calculation of the per-example gradients, which are utilized in the addition of noise during the application of differential privacy. `DPLSTM` has the same API and functionality as the `nn.LSTM`, with some restrictions (ex. we currently support single layers, the full list is given below).  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let us download the dataset of names and their associated language labels as given in https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html. We train our differentially-private LSTM on the same dataset as in that tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and extracting ...\n",
      "Completed!\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.simplefilter(\"ignore\")\n",
    "\n",
    "import os\n",
    "import requests\n",
    "\n",
    "\n",
    "NAMES_DATASET_URL = \"https://download.pytorch.org/tutorial/data.zip\"\n",
    "DATA_DIR = \"names\"\n",
    "\n",
    "import zipfile\n",
    "import urllib\n",
    "\n",
    "def download_and_extract(dataset_url, data_dir):\n",
    "    print(\"Downloading and extracting ...\")\n",
    "    filename = \"data.zip\"\n",
    "\n",
    "    urllib.request.urlretrieve(dataset_url, filename)\n",
    "    with zipfile.ZipFile(filename) as zip_ref:\n",
    "        zip_ref.extractall(data_dir)\n",
    "    os.remove(filename)\n",
    "    print(\"Completed!\")\n",
    "\n",
    "download_and_extract(NAMES_DATASET_URL, DATA_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Italian.txt', 'Arabic.txt', 'English.txt', 'German.txt', 'French.txt', 'Spanish.txt', 'Greek.txt', 'Dutch.txt', 'Korean.txt', 'Portuguese.txt', 'Japanese.txt', 'Polish.txt', 'Irish.txt', 'Chinese.txt', 'Russian.txt', 'Czech.txt', 'Vietnamese.txt', 'Scottish.txt']\n"
     ]
    }
   ],
   "source": [
    "names_folder = os.path.join(DATA_DIR, 'data', 'names')\n",
    "all_filenames = []\n",
    "\n",
    "for language_file in os.listdir(names_folder):\n",
    "    all_filenames.append(os.path.join(names_folder, language_file))\n",
    "    \n",
    "print(os.listdir(names_folder))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class CharByteEncoder(nn.Module):\n",
    "    \"\"\"\n",
    "    This encoder takes a UTF-8 string and encodes its bytes into a Tensor. It can also\n",
    "    perform the opposite operation to check a result.\n",
    "    Examples:\n",
    "    >>> encoder = CharByteEncoder()\n",
    "    >>> t = encoder('Ślusàrski')  # returns tensor([256, 197, 154, 108, 117, 115, 195, 160, 114, 115, 107, 105, 257])\n",
    "    >>> encoder.decode(t)  # returns \"<s>Ślusàrski</s>\"\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.start_token = \"<s>\"\n",
    "        self.end_token = \"</s>\"\n",
    "        self.pad_token = \"<pad>\"\n",
    "\n",
    "        self.start_idx = 256\n",
    "        self.end_idx = 257\n",
    "        self.pad_idx = 258\n",
    "\n",
    "    def forward(self, s: str, pad_to=0) -> torch.LongTensor:\n",
    "        \"\"\"\n",
    "        Encodes a string. It will append a start token <s> (id=self.start_idx) and an end token </s>\n",
    "        (id=self.end_idx).\n",
    "        Args:\n",
    "            s: The string to encode.\n",
    "            pad_to: If not zero, pad by appending self.pad_idx until string is of length `pad_to`.\n",
    "                Defaults to 0.\n",
    "        Returns:\n",
    "            The encoded LongTensor of indices.\n",
    "        \"\"\"\n",
    "        encoded = s.encode()\n",
    "        n_pad = pad_to - len(encoded) if pad_to > len(encoded) else 0\n",
    "        return torch.LongTensor(\n",
    "            [self.start_idx]\n",
    "            + [c for c in encoded]  # noqa\n",
    "            + [self.end_idx]\n",
    "            + [self.pad_idx for _ in range(n_pad)]\n",
    "        )\n",
    "\n",
    "    def decode(self, char_ids_tensor: torch.LongTensor) -> str:\n",
    "        \"\"\"\n",
    "        The inverse of `forward`. Keeps the start, end, and pad indices.\n",
    "        \"\"\"\n",
    "        char_ids = char_ids_tensor.cpu().detach().tolist()\n",
    "\n",
    "        out = []\n",
    "        buf = []\n",
    "        for c in char_ids:\n",
    "            if c < 256:\n",
    "                buf.append(c)\n",
    "            else:\n",
    "                if buf:\n",
    "                    out.append(bytes(buf).decode())\n",
    "                    buf = []\n",
    "                if c == self.start_idx:\n",
    "                    out.append(self.start_token)\n",
    "                elif c == self.end_idx:\n",
    "                    out.append(self.end_token)\n",
    "                elif c == self.pad_idx:\n",
    "                    out.append(self.pad_token)\n",
    "\n",
    "        if buf:  # in case some are left\n",
    "            out.append(bytes(buf).decode())\n",
    "        return \"\".join(out)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"\n",
    "        The length of our encoder space. This is fixed to 256 (one byte) + 3 special chars\n",
    "        (start, end, pad).\n",
    "        Returns:\n",
    "            259\n",
    "        \"\"\"\n",
    "        return 259"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training / Validation Set Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn.utils.rnn import pad_sequence\n",
    "\n",
    "def padded_collate(batch, padding_idx=0):\n",
    "    x = pad_sequence(\n",
    "        [elem[0] for elem in batch], batch_first=True, padding_value=padding_idx\n",
    "    )\n",
    "    y = torch.stack([elem[1] for elem in batch]).long()\n",
    "\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "class NamesDataset(Dataset):\n",
    "    def __init__(self, root):\n",
    "        self.root = Path(root)\n",
    "\n",
    "        self.labels = list({langfile.stem for langfile in self.root.iterdir()})\n",
    "        self.labels_dict = {label: i for i, label in enumerate(self.labels)}\n",
    "        self.encoder = CharByteEncoder()\n",
    "        self.samples = self.construct_samples()\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return self.samples[i]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.samples)\n",
    "\n",
    "    def construct_samples(self):\n",
    "        samples = []\n",
    "        for langfile in self.root.iterdir():\n",
    "            label_name = langfile.stem\n",
    "            label_id = self.labels_dict[label_name]\n",
    "            with open(langfile, \"r\") as fin:\n",
    "                for row in fin:\n",
    "                    samples.append(\n",
    "                        (self.encoder(row.strip()), torch.tensor(label_id).long())\n",
    "                    )\n",
    "        return samples\n",
    "\n",
    "    def label_count(self):\n",
    "        cnt = Counter()\n",
    "        for _x, y in self.samples:\n",
    "            label = self.labels[int(y)]\n",
    "            cnt[label] += 1\n",
    "        return cnt\n",
    "\n",
    "\n",
    "VOCAB_SIZE = 256 + 3  # 256 alternatives in one byte, plus 3 special characters.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We split the dataset into a 80-20 split for training and validation. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16059 samples for training, 4015 for testing\n"
     ]
    }
   ],
   "source": [
    "secure_mode = False\n",
    "train_split = 0.8\n",
    "test_every = 5\n",
    "batch_size = 800\n",
    "\n",
    "ds = NamesDataset(names_folder)\n",
    "train_len = int(train_split * len(ds))\n",
    "test_len = len(ds) - train_len\n",
    "\n",
    "print(f\"{train_len} samples for training, {test_len} for testing\")\n",
    "\n",
    "train_ds, test_ds = torch.utils.data.random_split(ds, [train_len, test_len])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "train_loader = DataLoader(\n",
    "    train_ds,\n",
    "    batch_size=batch_size,\n",
    "    pin_memory=True,\n",
    "    collate_fn=padded_collate,\n",
    ")\n",
    "\n",
    "test_loader = DataLoader(\n",
    "    test_ds,\n",
    "    batch_size=2 * batch_size,\n",
    "    shuffle=False,\n",
    "    pin_memory=True,\n",
    "    collate_fn=padded_collate,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After splitting the dataset into a training and a validation set, we now have to convert the data into a numeric form suitable for training the LSTM model. For each name, we set a maximum sequence length of 15, and if a name is longer than the threshold, we truncate it (this rarely happens in this dataset!). If a name is smaller than the threshold, we add a dummy `#` character to pad it to the desired length. We also batch the names in the dataset and set a batch size of 256 for all the experiments in this tutorial. The function `line_to_tensor()` returns a tensor of shape [15, 256] where each element is the index (in `all_letters`) of the corresponding character."
   ]
  },
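  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch in plain Python (no `torch`), the cell below mirrors what `CharByteEncoder` and `padded_collate` do to a batch of names: each name becomes its UTF-8 byte values framed by the start (256) and end (257) ids, and shorter sequences are right-padded with the default padding value of `padded_collate` (0). The helper names `encode_name` and `pad_batch` are hypothetical and exist only for this illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: plain-Python stand-ins for CharByteEncoder and padded_collate.\n",
    "START_IDX, END_IDX = 256, 257  # special ids used by CharByteEncoder\n",
    "PAD_VALUE = 0  # default padding value of padded_collate\n",
    "\n",
    "\n",
    "def encode_name(name):\n",
    "    # Frame the UTF-8 bytes of the name with the start and end ids.\n",
    "    return [START_IDX] + list(name.encode(\"utf-8\")) + [END_IDX]\n",
    "\n",
    "\n",
    "def pad_batch(sequences, pad_value=PAD_VALUE):\n",
    "    # Right-pad every sequence to the length of the longest one in the batch.\n",
    "    max_len = max(len(seq) for seq in sequences)\n",
    "    return [seq + [pad_value] * (max_len - len(seq)) for seq in sequences]\n",
    "\n",
    "\n",
    "batch = pad_batch([encode_name(name) for name in [\"Ng\", \"Ślusàrski\"]])\n",
    "for row in batch:\n",
    "    print(len(row), row)"
   ]
  },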
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training/Evaluation Cycle "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training and the evaluation functions `train()` and `test()` are defined below. During the training loop, the per-example gradients are computed and the parameters are updated subsequent to gradient clipping (to bound their sensitivity) and addition of noise.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from statistics import mean\n",
    "\n",
    "def train(model, criterion, optimizer, train_loader, epoch, privacy_engine, device=\"cuda:0\"):\n",
    "    accs = []\n",
    "    losses = []\n",
    "    for x, y in train_loader:\n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "\n",
    "        logits = model(x)\n",
    "        loss = criterion(logits, y)\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        preds = logits.argmax(-1)\n",
    "        n_correct = float(preds.eq(y).sum())\n",
    "        batch_accuracy = n_correct / len(y)\n",
    "\n",
    "        accs.append(batch_accuracy)\n",
    "        losses.append(float(loss))\n",
    "\n",
    "    printstr = (\n",
    "        f\"\\t Epoch {epoch}. Accuracy: {mean(accs):.6f} | Loss: {mean(losses):.6f}\"\n",
    "    )\n",
    "    if privacy_engine:\n",
    "        epsilon = privacy_engine.get_epsilon(delta)\n",
    "        printstr += f\" | (ε = {epsilon:.2f}, δ = {delta})\"\n",
    "\n",
    "    print(printstr)\n",
    "    return\n",
    "\n",
    "\n",
    "def test(model, test_loader, privacy_engine, device=\"cuda:0\"):\n",
    "    accs = []\n",
    "    with torch.no_grad():\n",
    "        for x, y in test_loader:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "\n",
    "            preds = model(x).argmax(-1)\n",
    "            n_correct = float(preds.eq(y).sum())\n",
    "            batch_accuracy = n_correct / len(y)\n",
    "\n",
    "            accs.append(batch_accuracy)\n",
    "    printstr = \"\\n----------------------------\\n\" f\"Test Accuracy: {mean(accs):.6f}\"\n",
    "    if privacy_engine:\n",
    "        epsilon = privacy_engine.get_epsilon(delta)\n",
    "        printstr += f\" (ε = {epsilon:.2f}, δ = {delta})\"\n",
    "    print(printstr + \"\\n----------------------------\\n\")\n",
    "    return\n"
   ]
  },
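  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the clipping-and-noise step concrete, here is a minimal sketch in plain Python of what a DP-SGD update does to the gradients of one batch. It is an illustration under simplifying assumptions, not the Opacus implementation: the function `dp_average_gradient` and its arguments are hypothetical, gradients are plain lists of floats, and the noisy sum is divided by the actual batch size. In this notebook the equivalent work is done by the private optimizer returned by the privacy engine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import random\n",
    "\n",
    "\n",
    "def dp_average_gradient(per_sample_grads, max_grad_norm, noise_multiplier, rng):\n",
    "    # Hypothetical helper: clip each per-sample gradient, sum, add noise, average.\n",
    "    dim = len(per_sample_grads[0])\n",
    "    total = [0.0] * dim\n",
    "    for grad in per_sample_grads:\n",
    "        norm = math.sqrt(sum(g * g for g in grad))\n",
    "        # Scale down only the gradients whose L2 norm exceeds the bound.\n",
    "        scale = min(1.0, max_grad_norm / (norm + 1e-6))\n",
    "        for i, g in enumerate(grad):\n",
    "            total[i] += g * scale\n",
    "    # The noise standard deviation is proportional to the clipping bound.\n",
    "    std = noise_multiplier * max_grad_norm\n",
    "    noisy = [t + rng.gauss(0.0, std) for t in total]\n",
    "    return [n / len(per_sample_grads) for n in noisy]\n",
    "\n",
    "\n",
    "rng = random.Random(0)\n",
    "grads = [[3.0, 4.0], [0.3, 0.4]]  # L2 norms 5.0 and 0.5\n",
    "clipped_only = dp_average_gradient(grads, max_grad_norm=1.5, noise_multiplier=0.0, rng=rng)\n",
    "clipped_noisy = dp_average_gradient(grads, max_grad_norm=1.5, noise_multiplier=1.0, rng=rng)\n",
    "print(\"clipped, no noise:  \", clipped_only)\n",
    "print(\"clipped, with noise:\", clipped_noisy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `noise_multiplier=0.0` only the clipping is visible: the first gradient is scaled down to norm 1.5, while the second is already within the bound and is left unchanged."
   ]
  },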
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyper-parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are two sets of hyper-parameters associated with this model. The first are hyper-parameters which we would expect in any machine learning training, such as the learning rate and batch size. The second set are related to the privacy engine, where for example we define the amount of noise added to the gradients (`noise_multiplier`), and the maximum L2 norm to which the per-sample gradients are clipped (`max_grad_norm`). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training hyper-parameters\n",
    "epochs = 50\n",
    "learning_rate = 2.0\n",
    "\n",
    "# Privacy engine hyper-parameters\n",
    "max_per_sample_grad_norm = 1.5\n",
    "delta = 8e-5\n",
    "epsilon = 12.0"
   ]
  },
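  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a back-of-the-envelope sketch of how these numbers interact, the cell below counts the noisy updates over which the budget (`epsilon`, `delta`) is spent. It copies the training-set size (16059) and batch size (800) from the cells above, and it assumes that the sampling rate is one over the number of batches per epoch; the variable names are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Values copied from the cells above so that this sketch runs on its own.\n",
    "n_train, train_batch_size, n_epochs = 16059, 800, 50\n",
    "\n",
    "steps_per_epoch = math.ceil(n_train / train_batch_size)  # batches per pass over the data\n",
    "sample_rate = 1 / steps_per_epoch  # assumed chance that a given name joins a batch\n",
    "total_steps = n_epochs * steps_per_epoch  # number of noisy updates overall\n",
    "\n",
    "print(f\"steps per epoch: {steps_per_epoch}\")\n",
    "print(f\"sample rate: {sample_rate:.4f}\")\n",
    "print(f\"total noisy steps: {total_steps}\")"
   ]
  },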
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define the name classification model in the cell below. Note that it is a simple char-LSTM classifier, where the input characters are passed through an `nn.Embedding` layer, and are subsequently input to the DPLSTM. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from opacus.layers import DPLSTM\n",
    "\n",
    "class CharNNClassifier(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        embedding_size,\n",
    "        hidden_size,\n",
    "        output_size,\n",
    "        num_lstm_layers=1,\n",
    "        bidirectional=False,\n",
    "        vocab_size=VOCAB_SIZE,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.embedding_size = embedding_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.output_size = output_size\n",
    "        self.vocab_size = vocab_size\n",
    "\n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_size)\n",
    "        self.lstm = DPLSTM(\n",
    "            embedding_size,\n",
    "            hidden_size,\n",
    "            num_layers=num_lstm_layers,\n",
    "            bidirectional=bidirectional,\n",
    "            batch_first=True,\n",
    "        )\n",
    "        self.out_layer = nn.Linear(hidden_size, output_size)\n",
    "\n",
    "    def forward(self, x, hidden=None):\n",
    "        x = self.embedding(x)  # -> [B, T, D]\n",
    "        x, _ = self.lstm(x, hidden)  # -> [B, T, H]\n",
    "        x = x[:, -1, :]  # -> [B, H]\n",
    "        x = self.out_layer(x)  # -> [B, C]\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now proceed to instantiate the objects (privacy engine, model and optimizer) for our differentially-private LSTM training.  However, the `nn.LSTM` is replaced with a `DPLSTM` module which enables us to calculate per-example gradients. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the device to run on a GPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Define classifier parameters\n",
    "embedding_size = 64\n",
    "hidden_size = 128  # Number of neurons in hidden layer after LSTM\n",
    "n_lstm_layers = 1\n",
    "bidirectional_lstm = False\n",
    "\n",
    "model = CharNNClassifier(\n",
    "    embedding_size,\n",
    "    hidden_size,\n",
    "    len(ds.labels),\n",
    "    n_lstm_layers,\n",
    "    bidirectional_lstm,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining the privacy engine, optimizer and loss criterion for the problem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from opacus import PrivacyEngine\n",
    "privacy_engine = PrivacyEngine(secure_mode=secure_mode)\n",
    "\n",
    "model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(\n",
    "    module=model,\n",
    "    optimizer=optimizer,\n",
    "    data_loader=train_loader,\n",
    "    max_grad_norm=max_per_sample_grad_norm,\n",
    "    target_delta=delta,\n",
    "    target_epsilon=epsilon,\n",
    "    epochs=epochs,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the name classifier with privacy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can start training! We will be training for 50 epochs iterations (where each epoch corresponds to a pass over the whole dataset). We will be reporting the privacy epsilon every `test_every` epoch. We will also benchmark this differentially-private model against a model without privacy and obtain almost identical performance. Further, the private model trained with Opacus incurs only minimal overhead in training time, with the differentially-private classifier only slightly slower (by a couple of minutes) than the non-private model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train stats: \n",
      "\n",
      "\t Epoch 0. Accuracy: 0.428835 | Loss: 2.220773\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.469154 (ε = 2.30, δ = 8e-05)\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 1. Accuracy: 0.472534 | Loss: 1.895850\n",
      "\t Epoch 2. Accuracy: 0.471778 | Loss: 1.893783\n",
      "\t Epoch 3. Accuracy: 0.459604 | Loss: 1.958717\n",
      "\t Epoch 4. Accuracy: 0.491896 | Loss: 1.782331\n",
      "\t Epoch 5. Accuracy: 0.540205 | Loss: 1.577036\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.559490 (ε = 4.16, δ = 8e-05)\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 6. Accuracy: 0.593796 | Loss: 1.456133\n",
      "\t Epoch 7. Accuracy: 0.616827 | Loss: 1.388250\n",
      "\t Epoch 8. Accuracy: 0.632560 | Loss: 1.345773\n",
      "\t Epoch 9. Accuracy: 0.639074 | Loss: 1.327238\n",
      "\t Epoch 10. Accuracy: 0.650502 | Loss: 1.316831\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.650821 (ε = 5.43, δ = 8e-05)\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 11. Accuracy: 0.649294 | Loss: 1.315323\n",
      "\t Epoch 12. Accuracy: 0.656350 | Loss: 1.288794\n",
      "\t Epoch 13. Accuracy: 0.656104 | Loss: 1.285352\n",
      "\t Epoch 14. Accuracy: 0.656424 | Loss: 1.283710\n",
      "\t Epoch 15. Accuracy: 0.666633 | Loss: 1.273102\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.667164 (ε = 6.51, δ = 8e-05)\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 16. Accuracy: 0.672707 | Loss: 1.247125\n",
      "\t Epoch 17. Accuracy: 0.680121 | Loss: 1.223817\n",
      "\t Epoch 18. Accuracy: 0.686456 | Loss: 1.214923\n",
      "\t Epoch 19. Accuracy: 0.694982 | Loss: 1.193048\n",
      "\t Epoch 20. Accuracy: 0.694282 | Loss: 1.184953\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.682519 (ε = 7.46, δ = 8e-05)\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 21. Accuracy: 0.701802 | Loss: 1.161172\n",
      "\t Epoch 22. Accuracy: 0.706358 | Loss: 1.166274\n",
      "\t Epoch 23. Accuracy: 0.722667 | Loss: 1.097268\n",
      "\t Epoch 24. Accuracy: 0.703950 | Loss: 1.185700\n",
      "\t Epoch 25. Accuracy: 0.720196 | Loss: 1.112226\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.707127 (ε = 8.33, δ = 8e-05)\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 26. Accuracy: 0.720644 | Loss: 1.115221\n",
      "\t Epoch 27. Accuracy: 0.708652 | Loss: 1.158104\n",
      "\t Epoch 28. Accuracy: 0.724744 | Loss: 1.119688\n",
      "\t Epoch 29. Accuracy: 0.733490 | Loss: 1.088846\n",
      "\t Epoch 30. Accuracy: 0.729441 | Loss: 1.089938\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.701941 (ε = 9.15, δ = 8e-05)\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 31. Accuracy: 0.731014 | Loss: 1.096586\n",
      "\t Epoch 32. Accuracy: 0.736907 | Loss: 1.065786\n",
      "\t Epoch 33. Accuracy: 0.733743 | Loss: 1.098627\n",
      "\t Epoch 34. Accuracy: 0.741741 | Loss: 1.064197\n",
      "\t Epoch 35. Accuracy: 0.742394 | Loss: 1.053995\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.720777 (ε = 9.93, δ = 8e-05)\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 36. Accuracy: 0.749420 | Loss: 1.034596\n",
      "\t Epoch 37. Accuracy: 0.748662 | Loss: 1.037211\n",
      "\t Epoch 38. Accuracy: 0.745869 | Loss: 1.061525\n",
      "\t Epoch 39. Accuracy: 0.751734 | Loss: 1.022538\n",
      "\t Epoch 40. Accuracy: 0.751194 | Loss: 1.028292\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.744636 (ε = 10.67, δ = 8e-05)\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 41. Accuracy: 0.754300 | Loss: 1.032082\n",
      "\t Epoch 42. Accuracy: 0.753252 | Loss: 1.017024\n",
      "\t Epoch 43. Accuracy: 0.755629 | Loss: 1.035767\n",
      "\t Epoch 44. Accuracy: 0.758195 | Loss: 1.029165\n",
      "\t Epoch 45. Accuracy: 0.751091 | Loss: 1.028669\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.739427 (ε = 11.38, δ = 8e-05)\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 46. Accuracy: 0.760692 | Loss: 0.995788\n",
      "\t Epoch 47. Accuracy: 0.763821 | Loss: 0.990309\n",
      "\t Epoch 48. Accuracy: 0.763423 | Loss: 0.997126\n",
      "\t Epoch 49. Accuracy: 0.767976 | Loss: 0.982944\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.752090 (ε = 11.93, δ = 8e-05)\n",
      "----------------------------\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(\"Train stats: \\n\")\n",
    "for epoch in range(epochs):\n",
    "    train(model, criterion, optimizer, train_loader, epoch, privacy_engine, device=device)\n",
    "    if test_every:\n",
    "        if epoch % test_every == 0:\n",
    "            test(model, test_loader, privacy_engine, device=device)\n",
    "\n",
    "test(model, test_loader, privacy_engine, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The differentially-private name classification model obtains a test accuracy of 0.75 with an epsilon of just under 12. This shows that we can achieve good accuracy on this task, with minimal loss of privacy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the name classifier without privacy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " We also run a comparison with a non-private model to see if the performance obtained with privacy is comparable to it. To do this, we keep the parameters such as learning rate and batch size the same, and only define a different instance of the model along with a separate optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_nodp = CharNNClassifier(\n",
    "    embedding_size,\n",
    "    hidden_size,\n",
    "    len(ds.labels),\n",
    "    n_lstm_layers,\n",
    "    bidirectional_lstm,\n",
    ").to(device)\n",
    "\n",
    "\n",
    "optimizer_nodp = torch.optim.SGD(model_nodp.parameters(), lr=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\t Epoch 0. Accuracy: 0.423231 | Loss: 1.957621\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.469154\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 1. Accuracy: 0.470835 | Loss: 1.850998\n",
      "\t Epoch 2. Accuracy: 0.461741 | Loss: 1.845881\n",
      "\t Epoch 3. Accuracy: 0.466039 | Loss: 1.848411\n",
      "\t Epoch 4. Accuracy: 0.470612 | Loss: 1.857506\n",
      "\t Epoch 5. Accuracy: 0.460152 | Loss: 1.845789\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.469154\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 6. Accuracy: 0.477714 | Loss: 1.775618\n",
      "\t Epoch 7. Accuracy: 0.518488 | Loss: 1.622382\n",
      "\t Epoch 8. Accuracy: 0.535421 | Loss: 1.565642\n",
      "\t Epoch 9. Accuracy: 0.545521 | Loss: 1.511846\n",
      "\t Epoch 10. Accuracy: 0.543908 | Loss: 1.514014\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.575170\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 11. Accuracy: 0.561950 | Loss: 1.454853\n",
      "\t Epoch 12. Accuracy: 0.605502 | Loss: 1.388555\n",
      "\t Epoch 13. Accuracy: 0.607155 | Loss: 1.367188\n",
      "\t Epoch 14. Accuracy: 0.615066 | Loss: 1.346803\n",
      "\t Epoch 15. Accuracy: 0.621913 | Loss: 1.332553\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.635465\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 16. Accuracy: 0.619772 | Loss: 1.314691\n",
      "\t Epoch 17. Accuracy: 0.629337 | Loss: 1.302999\n",
      "\t Epoch 18. Accuracy: 0.634173 | Loss: 1.277790\n",
      "\t Epoch 19. Accuracy: 0.647275 | Loss: 1.226866\n",
      "\t Epoch 20. Accuracy: 0.652142 | Loss: 1.226686\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.651832\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 21. Accuracy: 0.646773 | Loss: 1.219855\n",
      "\t Epoch 22. Accuracy: 0.663006 | Loss: 1.195204\n",
      "\t Epoch 23. Accuracy: 0.670526 | Loss: 1.165726\n",
      "\t Epoch 24. Accuracy: 0.676121 | Loss: 1.148621\n",
      "\t Epoch 25. Accuracy: 0.687536 | Loss: 1.109896\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.690590\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 26. Accuracy: 0.690961 | Loss: 1.110705\n",
      "\t Epoch 27. Accuracy: 0.674958 | Loss: 1.158181\n",
      "\t Epoch 28. Accuracy: 0.696233 | Loss: 1.091395\n",
      "\t Epoch 29. Accuracy: 0.699146 | Loss: 1.077446\n",
      "\t Epoch 30. Accuracy: 0.710076 | Loss: 1.061827\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.716664\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 31. Accuracy: 0.714624 | Loss: 1.040824\n",
      "\t Epoch 32. Accuracy: 0.709445 | Loss: 1.044048\n",
      "\t Epoch 33. Accuracy: 0.719751 | Loss: 1.021937\n",
      "\t Epoch 34. Accuracy: 0.722247 | Loss: 1.002287\n",
      "\t Epoch 35. Accuracy: 0.725602 | Loss: 0.985023\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.717073\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 36. Accuracy: 0.721840 | Loss: 0.990956\n",
      "\t Epoch 37. Accuracy: 0.726419 | Loss: 0.978770\n",
      "\t Epoch 38. Accuracy: 0.730414 | Loss: 0.945205\n",
      "\t Epoch 39. Accuracy: 0.733045 | Loss: 0.931660\n",
      "\t Epoch 40. Accuracy: 0.743858 | Loss: 0.914782\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.724982\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 41. Accuracy: 0.751916 | Loss: 0.876523\n",
      "\t Epoch 42. Accuracy: 0.737594 | Loss: 0.914662\n",
      "\t Epoch 43. Accuracy: 0.735986 | Loss: 0.923208\n",
      "\t Epoch 44. Accuracy: 0.752869 | Loss: 0.868417\n",
      "\t Epoch 45. Accuracy: 0.753095 | Loss: 0.867506\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.740716\n",
      "----------------------------\n",
      "\n",
      "\t Epoch 46. Accuracy: 0.755373 | Loss: 0.851085\n",
      "\t Epoch 47. Accuracy: 0.755981 | Loss: 0.842593\n",
      "\t Epoch 48. Accuracy: 0.768917 | Loss: 0.813079\n",
      "\t Epoch 49. Accuracy: 0.761222 | Loss: 0.829013\n",
      "\n",
      "----------------------------\n",
      "Test Accuracy: 0.754173\n",
      "----------------------------\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(epochs):\n",
    "    train(model_nodp, criterion, optimizer_nodp, train_loader, epoch, device=device)\n",
    "    if test_every:\n",
    "        if epoch % test_every == 0:\n",
    "            test(model_nodp, test_loader, None, device=device)\n",
    "\n",
    "test(model_nodp, test_loader, None, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We run the training loop again, this time without privacy and for the same number of iterations. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The non-private classifier obtains a test accuracy of around 0.75 with the same parameters and number of epochs. We are effectively trading off performance on the name classification task for a lower loss of privacy."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
